from foolbox.models import PyTorchModel
import numpy as np
import torch
import copy
from tqdm import tqdm


def sample_hypersphere(n, d, qmc=False, seed=None, dtype=None):
    r"""Sample uniformly from the surface of the unit hypersphere in R^d.

    Standard-normal vectors are drawn and each one is divided by its own
    L2 norm.

    Args:
        n: The number of samples to return.
        d: The dimension of each sample.
        qmc: Accepted for signature compatibility; currently ignored, the
            draw is always i.i.d. normal.
        seed: If provided, seeds a private ``torch.Generator`` so the draw is
            reproducible without touching the global RNG state.
        dtype: The torch dtype (defaults to ``torch.float``).
    Returns:
        An `n x d` tensor whose rows have unit L2 norm. The tensor is moved
        to the current CUDA device when CUDA is available and is left on the
        CPU otherwise.
    Example:
        >>> sample_hypersphere(n=10, d=5)
    """
    dtype = torch.float if dtype is None else dtype
    if seed is None:
        rnd = torch.randn(n, d, dtype=dtype)
    else:
        # A dedicated generator keeps seeded calls from perturbing the
        # global random stream used elsewhere.
        generator = torch.Generator().manual_seed(seed)
        rnd = torch.randn(n, d, dtype=dtype, generator=generator)
    samples = rnd / torch.norm(rnd, dim=-1, keepdim=True)
    # Only move to the GPU when one exists, so CPU-only hosts do not crash.
    if torch.cuda.is_available():
        samples = samples.cuda()
    return samples


def rand_ptb(ptb_org):
    """Draw a random direction per sample, scaled to that sample's L2 norm.

    Args:
        ptb_org: Tensor whose first axis indexes samples; all remaining axes
            are flattened into one feature axis.

    Returns:
        A ``(batch, features)`` tensor. Row ``i`` is a random unit vector from
        ``sample_hypersphere`` multiplied by the L2 norm of the flattened
        row ``i`` of ``ptb_org``. The result is flat; callers reshape it back
        to the input shape themselves.
    """
    # Nothing below mutates the input, so no defensive copy is needed.
    flat = ptb_org.reshape(ptb_org.shape[0], -1)
    row_norms = torch.norm(flat, dim=-1, keepdim=True)
    directions = sample_hypersphere(flat.shape[0], flat.shape[1])
    # sample_hypersphere picks its own device; align it with the input so the
    # product does not fail for CPU inputs or inputs on a non-default GPU.
    return directions.to(row_norms.device) * row_norms


def reshape_out(out, _to, _type, c_dim=3):
    """Bring a batch of images into the axis layout (and dtype) of a target.

    Args:
        out: A numpy array with the batch on axis 0, or a list of per-image
            arrays that is stacked along a new leading axis.
        _to: Target shape (tuple-like). The position of the axis of size
            ``c_dim`` in it decides where the channel axis must end up.
        _type: Target dtype. When it equals ``np.uint8`` the values are
            multiplied by 255, clipped to ``[0, 255]`` and cast to uint8;
            otherwise the values are returned unchanged.
        c_dim: Size of the channel axis, used to locate it in both shapes.

    Returns:
        The batch with its channel axis moved to the target position.

    Raises:
        ValueError: If no axis after the batch axis has size ``c_dim`` in
            ``out.shape`` or in ``_to``.
    """
    if isinstance(out, list):
        out = np.stack(out)
    # Start the search at axis 1: axis 0 is the batch axis, and a batch whose
    # size happens to equal c_dim must not be mistaken for the channel axis.
    src_axis = out.shape.index(c_dim, 1)
    dst_axis = _to.index(c_dim, 1)
    if src_axis != dst_axis:
        # For NCHW -> NHWC this is exactly transpose((0, 2, 3, 1)); unlike a
        # fixed permutation it is also correct for other source/target pairs.
        out = np.moveaxis(out, src_axis, dst_axis)
    if _type == np.uint8:
        # Clip before the cast so out-of-range values saturate instead of
        # wrapping around.
        return np.clip(out * 255.0, 0.0, 255.0).astype(np.uint8)
    return out


def att_helper(att_str):
    """Resolve an attack spec such as ``"pgd-10"`` to ``(attack_fn, num_steps)``.

    Args:
        att_str: ``"<name>"`` or ``"<name>-<steps>"``. Known names are
            ``pgd``, ``ifgsm``, ``l2``, ``mixup``, ``rand``, ``cw`` and ``aa``.

    Returns:
        A ``(callable, num_steps)`` pair. For ``cw`` and ``aa`` the step count
        is always ``None``. For the other attacks it is the parsed integer,
        or 1 when the spec carries no ``-<steps>`` suffix.

    Raises:
        ValueError: If the name is unknown, or the step suffix is not an
            integer.
    """
    params = att_str.split("-")
    name = params[0]
    # Without a suffix the step count used to be left unbound; fall back to
    # one step, the same default the attack functions in this module use.
    num_steps = int(params[1]) if len(params) == 2 else 1

    stepped = {
        "pgd": pgd_attack,
        "ifgsm": ifgsm_attack,
        "l2": l2_attack,
        "mixup": mixup_attack,
        "rand": rand_attack,
    }
    unstepped = {
        "cw": cw_attack,
        "aa": aa_attack,
    }
    if name in stepped:
        return stepped[name], num_steps
    if name in unstepped:
        return unstepped[name], None
    # Fail loudly rather than returning None and breaking the caller's unpack.
    raise ValueError("unknown attack: {!r}".format(att_str))


def iter_foolbox(adv_attack, fmodel, dl, attack_params, epsilon, num_steps=1, c_dim=3):
    """Apply a foolbox attack for ``num_steps`` rounds, chaining the rounds.

    Each round instantiates ``adv_attack`` and attacks every batch of a
    private deep copy of ``dl``. The adversarial images of that round then
    replace ``dataset.data`` of the copy, so the next round starts from them.

    Args:
        adv_attack: Attack class, instantiated as ``adv_attack(**params)``.
        fmodel: Model wrapper passed to the attack; also called directly on
            the adversarial batch to obtain logits.
        dl: Loader yielding ``(x, y)`` batches; its dataset exposes ``data``.
        attack_params: Keyword arguments for the attack constructor. If it
            contains ``"random_start"``, that flag is only kept for round 0.
        epsilon: Forwarded to the attack as ``epsilons``.
        num_steps: Number of rounds.
        c_dim: Channel size handed to ``reshape_out``.

    Returns:
        ``(adv_imgs, adv_logits, adv_ys)``, each a list with one entry per
        round: the reshaped image array, the list of per-sample logits and
        the list of per-sample labels.
    """
    working_dl = copy.deepcopy(dl)
    round_params = copy.deepcopy(attack_params)
    has_random_start = "random_start" in attack_params

    all_imgs, all_logits, all_ys = [], [], []
    rounds = tqdm(range(num_steps), desc="Collecting Adv {} step)".format(num_steps))
    for step in rounds:
        if has_random_start:
            # Random initialisation is only wanted on the very first round.
            round_params["random_start"] = attack_params["random_start"] and step == 0
        attack = adv_attack(**round_params)

        round_imgs, round_logits, round_ys = [], [], []
        for x, y in working_dl:
            x = x.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            _, adv_x, _ = attack(fmodel, x, y, epsilons=epsilon)
            batch_logits = fmodel(adv_x)
            round_logits.extend(batch_logits.detach().cpu().numpy())
            round_ys.extend(y.detach().cpu().numpy())
            round_imgs.extend(adv_x.detach().cpu().numpy())

        dataset = working_dl.dataset
        round_imgs = reshape_out(round_imgs, _to=dataset.data.shape,
                                 _type=dataset.data.dtype, c_dim=c_dim)
        # Feed this round's output back in as the next round's input.
        dataset.data = round_imgs

        all_imgs.append(round_imgs)
        all_logits.append(round_logits)
        all_ys.append(round_ys)
    return all_imgs, all_logits, all_ys


def pgd_attack(model, dl, args, num_step, bounds=(-1., 1.)):
    """Run foolbox's Linf projected-gradient-descent attack via ``iter_foolbox``.

    The attack itself is configured with a single step; ``iter_foolbox``
    repeats it ``num_step`` times.

    Args:
        model: Network wrapped in a foolbox ``PyTorchModel`` on ``"cuda"``.
        dl: Data loader forwarded to ``iter_foolbox``.
        args: Provides ``step_size``, ``random_start``, ``epsilon``, ``in_ch``.
        num_step: Number of rounds for ``iter_foolbox``.
        bounds: Input value range handed to ``PyTorchModel``.

    Returns:
        Whatever ``iter_foolbox`` returns.
    """
    from foolbox.attacks import LinfProjectedGradientDescentAttack as attack_cls

    single_step_cfg = dict(
        abs_stepsize=args.step_size,
        random_start=args.random_start,
        steps=1,
    )
    wrapped = PyTorchModel(model, bounds=bounds, device="cuda")
    return iter_foolbox(
        adv_attack=attack_cls,
        fmodel=wrapped,
        attack_params=single_step_cfg,
        epsilon=args.epsilon,
        num_steps=num_step,
        dl=dl,
        c_dim=args.in_ch,
    )


def ifgsm_attack(model, dl, args, num_step, bounds=(-1., 1.)):
    """Run foolbox's Linf basic iterative attack via ``iter_foolbox``.

    The attack itself is configured with a single step; ``iter_foolbox``
    repeats it ``num_step`` times.

    Args:
        model: Network wrapped in a foolbox ``PyTorchModel`` on ``"cuda"``.
        dl: Data loader forwarded to ``iter_foolbox``.
        args: Provides ``step_size``, ``random_start``, ``epsilon``, ``in_ch``.
        num_step: Number of rounds for ``iter_foolbox``.
        bounds: Input value range handed to ``PyTorchModel``.

    Returns:
        Whatever ``iter_foolbox`` returns.
    """
    from foolbox.attacks import LinfBasicIterativeAttack as attack_cls

    single_step_cfg = dict(
        abs_stepsize=args.step_size,
        random_start=args.random_start,
        steps=1,
    )
    wrapped = PyTorchModel(model, bounds=bounds, device="cuda")
    return iter_foolbox(
        adv_attack=attack_cls,
        fmodel=wrapped,
        attack_params=single_step_cfg,
        epsilon=args.epsilon,
        num_steps=num_step,
        dl=dl,
        c_dim=args.in_ch,
    )


def cw_attack(model, dl, args, num_step=None, bounds=(-1., 1.)):
    """Run foolbox's L2 Carlini-Wagner attack via ``iter_foolbox``.

    ``iter_foolbox`` is called without a round count, so its own default
    applies.

    Args:
        model: Network wrapped in a foolbox ``PyTorchModel`` on ``"cuda"``.
        dl: Data loader forwarded to ``iter_foolbox``.
        args: Provides ``binary_search_steps``, ``max_iterations``, ``lr``,
            ``initial_const``, ``epsilon`` and ``in_ch``.
        num_step: Not used by this function.
        bounds: Input value range handed to ``PyTorchModel``.

    Returns:
        Whatever ``iter_foolbox`` returns.
    """
    from foolbox.attacks import L2CarliniWagnerAttack as attack_cls

    cw_cfg = dict(
        binary_search_steps=args.binary_search_steps,
        steps=args.max_iterations,
        stepsize=args.lr,
        initial_const=args.initial_const,
        abort_early=True,
    )
    wrapped = PyTorchModel(model, bounds=bounds, device="cuda")
    return iter_foolbox(
        adv_attack=attack_cls,
        fmodel=wrapped,
        attack_params=cw_cfg,
        epsilon=args.epsilon,
        dl=dl,
        c_dim=args.in_ch,
    )


def l2_attack(model, dl, args, num_step, bounds=(-1., 1.)):
    """Run foolbox's L2 basic iterative attack via ``iter_foolbox``.

    The attack itself is configured with a single step; ``iter_foolbox``
    repeats it ``num_step`` times.

    Args:
        model: Network wrapped in a foolbox ``PyTorchModel`` on ``"cuda"``.
        dl: Data loader forwarded to ``iter_foolbox``.
        args: Provides ``step_size``, ``random_start``, ``epsilon``, ``in_ch``.
        num_step: Number of rounds for ``iter_foolbox``.
        bounds: Input value range handed to ``PyTorchModel``.

    Returns:
        Whatever ``iter_foolbox`` returns.
    """
    from foolbox.attacks import L2BasicIterativeAttack as attack_cls

    single_step_cfg = dict(
        abs_stepsize=args.step_size,
        random_start=args.random_start,
        steps=1,
    )
    wrapped = PyTorchModel(model, bounds=bounds, device="cuda")
    return iter_foolbox(
        adv_attack=attack_cls,
        fmodel=wrapped,
        attack_params=single_step_cfg,
        epsilon=args.epsilon,
        num_steps=num_step,
        dl=dl,
        c_dim=args.in_ch,
    )


def aa_attack(model, dl, args, num_step=None):
    """Craft torchattacks ``AutoAttack`` adversarial images for every batch.

    Args:
        model: Network under attack. If it exposes a ``module`` attribute
            (as a wrapper would), that inner module is attacked; otherwise
            the model itself is used.
        dl: Loader yielding ``(x, y)`` batches; ``dl.dataset.data`` supplies
            the target layout and dtype for the returned images.
        args: Provides ``epsilon``, ``num_labels`` and ``in_ch``.
        num_step: Not used; accepted so this function can be called with the
            same positional arguments as ``cw_attack``.

    Returns:
        The adversarial images as converted by ``reshape_out`` to the layout
        and dtype of ``dl.dataset.data``.
    """
    from torchattacks import AutoAttack

    # Do not insist on a wrapper: fall back to the model itself.
    target = getattr(model, "module", model)
    adv_attack = AutoAttack(target,
                            eps=args.epsilon,
                            n_classes=args.num_labels)
    adv_imgs = []
    for x, y in dl:
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)
        adv_x = adv_attack(x, y)
        adv_imgs.extend(adv_x.detach().cpu().numpy())
    # reshape_out needs the target shape and dtype; the previous call passed
    # the channel count as the shape and no dtype, which raised TypeError.
    return reshape_out(adv_imgs, _to=dl.dataset.data.shape,
                       _type=dl.dataset.data.dtype, c_dim=args.in_ch)


def rand_attack(model, dl, args, num_step=1, bounds=(-1., 1.)):
    """Perturb every batch with random noise, for ``num_step`` chained rounds.

    For each sample the noise is a random direction from ``rand_ptb`` (whose
    L2 norm equals that of the sample) multiplied by ``args.epsilon``. After
    each round the perturbed images replace ``dataset.data`` of a private
    deep copy of ``dl``, so the next round starts from them.

    Args:
        model: Callable applied to the perturbed batch to obtain logits.
        dl: Loader yielding ``(x, y)`` batches; its dataset exposes ``data``.
        args: Provides ``epsilon`` and ``in_ch``.
        num_step: Number of rounds.
        bounds: Not used by this function; the perturbed images are not
            clipped to any range here.

    Returns:
        ``(adv_imgs, adv_logits, adv_ys)``, each a list with one entry per
        round: the reshaped image array, the list of per-sample logits and
        the list of per-sample labels.
    """
    epsilon = args.epsilon
    adv_imgs, adv_logits, adv_ys = [], [], []
    adv_dl = copy.deepcopy(dl)
    for _ in tqdm(range(num_step), desc="Collecting Adv {} step)".format(num_step)):
        step_imgs, step_logits, step_ys = [], [], []
        for x, y in adv_dl:
            x = x.cuda(non_blocking=True)
            ptbs = rand_ptb(x).reshape(x.shape)
            adv_x = x + ptbs * epsilon
            # Logits are only recorded, never differentiated, so skip
            # building an autograd graph for the forward pass.
            with torch.no_grad():
                logits = model(adv_x)
            step_logits.extend(logits.detach().cpu().numpy())
            step_ys.extend(y.detach().cpu().numpy())
            step_imgs.extend(adv_x.detach().cpu().numpy())
        step_imgs = reshape_out(step_imgs, _to=adv_dl.dataset.data.shape,
                                _type=adv_dl.dataset.data.dtype, c_dim=args.in_ch)
        # Feed this round's output back in as the next round's input.
        adv_dl.dataset.data = step_imgs
        adv_imgs.append(step_imgs)
        adv_logits.append(step_logits)
        adv_ys.append(step_ys)
    return adv_imgs, adv_logits, adv_ys


def mixup_attack(model, dl, args, num_step=1):
    """Blend each batch towards a batch drawn from a permuted dataset copy.

    Two deep copies of ``dl`` are made; the second has its ``dataset.data``
    reordered by a random permutation. Batches from both are iterated in
    lockstep and combined as ``x + (partner_x - x) * args.epsilon``.

    Args:
        model: Not used by this function.
        dl: Loader yielding ``(x, y)`` batches; ``dataset.data`` must support
            ``len`` and indexing with an index array.
        args: Provides ``epsilon`` and ``in_ch``.
        num_step: Not used by this function.

    Returns:
        ``([adv_imgs], None, None)`` where ``adv_imgs`` is the blended image
        array after ``reshape_out``.
    """
    mix_weight = args.epsilon

    base_dl = copy.deepcopy(dl)
    partner_dl = copy.deepcopy(dl)
    order = np.random.permutation(np.arange(len(base_dl.dataset.data)))
    partner_dl.dataset.data = base_dl.dataset.data[order]

    blended = []
    for (x, _), (partner_x, _) in zip(base_dl, partner_dl):
        x = x.cuda(non_blocking=True)
        partner_x = partner_x.cuda(non_blocking=True)
        offset = partner_x - x
        mixed = x + offset * mix_weight
        blended.extend(mixed.detach().cpu().numpy())

    reference = base_dl.dataset.data
    blended = reshape_out(blended, _to=reference.shape,
                          _type=reference.dtype, c_dim=args.in_ch)
    return [blended], None, None
